import logging
from utils.neo4j_provider import neo4j
import pandas as pd

# Set the root logger's level to INFO for this script.
# NOTE(review): nothing visible in this module logs (output goes through print());
# presumably this is for records emitted by imported modules such as utils.neo4j_provider.
logging.getLogger().setLevel(logging.INFO)


# Relationship de-duplication helper.
def deduplicate(relation_old) -> list:
    """
    Remove duplicate items from ``relation_old`` while keeping first-seen order.

    Items are compared by equality, just like a ``not in`` check against the result
    list. Hashable keys are tracked in a set, so each lookup is O(1) instead of a
    linear scan over the output (which made the old version O(n^2) in the number of
    relation pairs).

    :param relation_old: iterable of items, e.g. the ``[left, right]`` pairs built by
        generate_relation
    :return: a new list containing the first occurrence of each distinct item, in
        input order (the original item objects, not copies)
    """
    seen = set()
    relation_new = []
    for each in relation_old:
        # Lists are unhashable, so key them by a tuple copy. The `list` tag keeps
        # [a, b] distinct from the tuple (a, b), which `==` also treats as different.
        key = (list, tuple(each)) if isinstance(each, list) else each
        try:
            if key in seen:
                continue
            seen.add(key)
        except TypeError:
            # Still unhashable (e.g. a list nested inside a pair): fall back to the
            # original linear equality scan for this item.
            if each in relation_new:
                continue
        relation_new.append(each)
    return relation_new


def _escape_cypher_string(value) -> str:
    """Return ``str(value)`` escaped for use inside a single-quoted Cypher string literal."""
    # Escape backslashes first so the backslashes added for quotes are not doubled.
    return str(value).replace('\\', '\\\\').replace("'", "\\'")


def generate_cql(start_node, end_node, edges, rel_type, rel_name) -> list:
    """
    Build one ``MATCH ... CREATE`` Cypher statement per edge.

    :param start_node: label of the start node; start nodes are matched on ``name``
    :param end_node: label of the end node; end nodes are matched on ``name``
    :param edges: iterable of pairs whose items 0 and 1 are the start and end names
    :param rel_type: relationship type to create
    :param rel_name: value stored in the new relationship's ``name`` property
    :return: list of Cypher statements, one per edge, in input order
        (the previous ``-> str`` annotation was wrong)

    Node names and ``rel_name`` are escaped, so a value containing a quote or a
    backslash no longer breaks (or rewrites) the query. Labels and the relationship
    type are inserted verbatim and must be trusted, valid Cypher identifiers.
    """
    rel_value = _escape_cypher_string(rel_name)
    cql = []
    for edge in edges:
        p = edge[0]
        q = edge[1]
        # Cypher statement that creates the relationship between the two matched nodes.
        cql.append(
            f"MATCH(p:{start_node}),(q:{end_node}) "
            f"WHERE p.name='{_escape_cypher_string(p)}' and q.name='{_escape_cypher_string(q)}' "
            f"CREATE (p)-[rel:{rel_type}{{name:'{rel_value}'}}]->(q)"
        )
        print('创建关系 {}-{}->{}'.format(p, rel_type, q))
    return cql


def generate_relation(l_name, r_name, csv_path='relationship_data.csv') -> list:
    """
    Build the deduplicated list of ``[left, right]`` name pairs from two CSV columns.

    Each cell holds comma-separated names. Every left name in a row is paired with
    every right name in the same row.

    :param l_name: column holding the left-hand names
    :param r_name: column holding the right-hand names
    :param csv_path: CSV file to read (defaults to the previously hard-coded path)
    :return: list of ``[left_name, right_name]`` pairs, first-seen order, no duplicates
    """
    # dtype=str keeps every cell as text. A purely numeric column would otherwise be
    # parsed as numbers, which have no .split(). Empty cells still come back as NaN.
    df = pd.read_csv(csv_path, dtype=str)
    relation_list = []
    for l_cell, r_cell in zip(df[l_name], df[r_name]):
        # An empty cell is NaN; calling .split() on it raised AttributeError.
        if pd.isna(l_cell) or pd.isna(r_cell):
            continue
        # Skip empty tokens produced by doubled or trailing commas ("a,,b", "a,").
        l_nodes = [name for name in l_cell.split(',') if name]
        r_nodes = [name for name in r_cell.split(',') if name]
        relation_list.extend([l_node, r_node] for l_node in l_nodes for r_node in r_nodes)
    return deduplicate(relation_list)


def create_relationship(l_node, r_node, relationship, l_data_name, r_data_name, relation_name):
    """
    Rebuild the ``relationship`` edges between ``l_node`` and ``r_node`` nodes from the CSV.

    The relation pairs and Cypher statements are built BEFORE anything is deleted.
    Previously the delete ran first, so a failure while reading the CSV (missing file,
    unknown column) left the graph with the old relationships removed and nothing
    recreated.

    :param l_node: label of the left (start) node
    :param r_node: label of the right (end) node
    :param relationship: relationship type to delete and recreate
    :param l_data_name: CSV column holding the left node names
    :param r_data_name: CSV column holding the right node names
    :param relation_name: value stored in each new relationship's ``name`` property
    :return: None
    """
    relation_list = generate_relation(l_data_name, r_data_name)
    print(relation_list)
    cql_list = generate_cql(l_node, r_node, relation_list, relationship, relation_name)

    # NOTE(review): presumably removes existing `relationship` edges between these
    # labels so reruns do not create duplicates; confirm in utils.neo4j_provider.
    neo4j.delete_relationship(l_node, r_node, relationship)
    # Statements are executed one at a time; this is not a single transaction.
    for cql in cql_list:
        neo4j.execute_write(cql)
        print(cql)


def relationship_relation_check():
    """Rebuild Symptom -[need_check]-> Examine edges from the '症状' (symptom) and '检查' (examination) columns."""
    create_relationship(
        l_node="Symptom",
        r_node="Examine",
        relationship="need_check",
        l_data_name='症状',
        r_data_name='检查',
        relation_name='症状检查',
    )


def relationship_has_symptom():
    """Rebuild Disease -[has_symptom]-> Symptom edges from the '疾病' (disease) and '症状' (symptom) columns."""
    create_relationship(
        l_node="Disease",
        r_node="Symptom",
        relationship="has_symptom",
        l_data_name='疾病',
        r_data_name='症状',
        relation_name='症状',
    )


def relationship_used_drugs():
    """Rebuild Disease -[used_drugs]-> Drug edges from the '疾病' (disease) and '药品' (drug) columns."""
    create_relationship(
        l_node="Disease",
        r_node="Drug",
        relationship="used_drugs",
        l_data_name='疾病',
        r_data_name='药品',
        relation_name='常用药品',
    )


def relationship_doeat_foods():
    """Rebuild Disease -[doeat_foods]-> Foods edges from the '疾病' (disease) and '宜吃' (recommended food) columns."""
    create_relationship(
        l_node="Disease",
        r_node="Foods",
        relationship="doeat_foods",
        l_data_name='疾病',
        r_data_name='宜吃',
        relation_name='推荐食物',
    )


def relationship_noteat_foods():
    """Rebuild Disease -[noteat_foods]-> Foods edges from the '疾病' (disease) and '忌吃' (food to avoid) columns."""
    create_relationship(
        l_node="Disease",
        r_node="Foods",
        relationship="noteat_foods",
        l_data_name='疾病',
        r_data_name='忌吃',
        relation_name='忌吃食物',
    )


if __name__ == "__main__":
    # Each builder rebuilds one relationship type; they run in this fixed order.
    for build_relationships in (
        relationship_relation_check,  # which examinations a symptom calls for
        relationship_has_symptom,  # which symptoms a disease has
        relationship_used_drugs,  # drugs commonly used for a disease
        relationship_doeat_foods,  # recommended foods
        relationship_noteat_foods,  # foods to avoid
    ):
        build_relationships()
